{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "99cea58c-48bc-4af6-8358-df9695659983",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Build your own OpenAI Agent"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "673df1fe-eb6c-46ea-9a73-a96e7ae7942e",
   "metadata": {
    "tags": []
   },
   "source": [
    "With the [new OpenAI API](https://openai.com/blog/function-calling-and-other-api-updates) that supports function calling, it's never been easier to build your own agent!\n",
    "\n",
    "In this notebook tutorial, we show how to write your own OpenAI agent in **under 50 lines of code**. It is minimal yet feature-complete: it can carry on a conversation and use tools."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54b7bc2e-606f-411a-9490-fcfab9236dfc",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Initial Setup "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23e80e5b-aaee-4f23-b338-7ae62b08141f",
   "metadata": {},
   "source": [
    "Let's start by importing some simple building blocks.\n",
    "\n",
    "The main things we need are:\n",
    "1. the OpenAI API (using our own `llama_index` LLM class)\n",
    "2. a place to keep conversation history\n",
    "3. a definition for tools that our agent can use"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9d47283b-025e-4874-88ed-76245b22f82e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import json\n",
    "from typing import Sequence, List\n",
    "\n",
    "from llama_index.llms import OpenAI, ChatMessage\n",
    "from llama_index.tools import BaseTool, FunctionTool\n",
    "\n",
    "import nest_asyncio\n",
    "\n",
    "nest_asyncio.apply()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6fe08eb1-e638-4c00-9103-5c305bfacccf",
   "metadata": {},
   "source": [
    "Let's define some very simple calculator tools for our agent."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3dd3c4a6-f3e0-46f9-ad3b-7ba57d1bc992",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def multiply(a: int, b: int) -> int:\n",
    "    \"\"\"Multiply two integers and return the resulting integer.\"\"\"\n",
    "    return a * b\n",
    "\n",
    "\n",
    "multiply_tool = FunctionTool.from_defaults(fn=multiply)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "bfcfb78b-7d4f-48d9-8d4c-ffcded23e7ac",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def add(a: int, b: int) -> int:\n",
    "    \"\"\"Add two integers and return the resulting integer.\"\"\"\n",
    "    return a + b\n",
    "\n",
    "\n",
    "add_tool = FunctionTool.from_defaults(fn=add)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fbcbd5ea-f377-44a0-a492-4568daa8b0b6",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Agent Definition"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b737e6c-64eb-4ae6-a8f7-350b1953e612",
   "metadata": {},
   "source": [
    "Now, we define an agent that can hold a conversation and call tools in **under 50 lines of code**.\n",
    "\n",
    "The meat of the agent logic is in the `chat` method. At a high level, there are 3 steps:\n",
    "1. Call OpenAI to decide which tool (if any) to call and with what arguments.\n",
    "2. Call the tool with the arguments to obtain an output.\n",
    "3. Call OpenAI to synthesize a response from the conversation context and the tool output.\n",
    "\n",
    "If OpenAI decides that no tool is needed, steps 2 and 3 are skipped and its first response is returned directly.\n",
    "\n",
    "The `reset` method simply clears the conversation context, so we can start another conversation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a0e068f7-fd24-4f74-8243-5e6e4840f7a6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from typing import Optional\n",
    "\n",
    "\n",
    "class YourOpenAIAgent:\n",
    "    def __init__(\n",
    "        self,\n",
    "        tools: Sequence[BaseTool] = (),\n",
    "        llm: Optional[OpenAI] = None,\n",
    "        chat_history: Optional[List[ChatMessage]] = None,\n",
    "    ) -> None:\n",
    "        # Build defaults per instance: a mutable default list would be shared by every agent\n",
    "        self._llm = llm if llm is not None else OpenAI(temperature=0, model=\"gpt-3.5-turbo-0613\")\n",
    "        self._tools = {tool.metadata.name: tool for tool in tools}\n",
    "        self._chat_history = chat_history if chat_history is not None else []\n",
    "\n",
    "    def reset(self) -> None:\n",
    "        self._chat_history = []\n",
    "\n",
    "    def chat(self, message: str) -> str:\n",
    "        chat_history = self._chat_history\n",
    "        chat_history.append(ChatMessage(role=\"user\", content=message))\n",
    "        functions = [\n",
    "            tool.metadata.to_openai_function() for _, tool in self._tools.items()\n",
    "        ]\n",
    "\n",
    "        # Step 1: let the model decide whether to call a tool, and with what arguments\n",
    "        ai_message = self._llm.chat(chat_history, functions=functions).message\n",
    "        chat_history.append(ai_message)\n",
    "\n",
    "        function_call = ai_message.additional_kwargs.get(\"function_call\", None)\n",
    "        if function_call is not None:\n",
    "            # Step 2: run the requested tool and record its output in the history\n",
    "            function_message = self._call_function(function_call)\n",
    "            chat_history.append(function_message)\n",
    "            # Step 3: ask the model to synthesize a reply from the tool output\n",
    "            ai_message = self._llm.chat(chat_history).message\n",
    "            chat_history.append(ai_message)\n",
    "\n",
    "        return ai_message.content\n",
    "\n",
    "    def _call_function(self, function_call: dict) -> ChatMessage:\n",
    "        tool = self._tools[function_call[\"name\"]]\n",
    "        output = tool(**json.loads(function_call[\"arguments\"]))\n",
    "        return ChatMessage(\n",
    "            name=function_call[\"name\"],\n",
    "            content=str(output),\n",
    "            role=\"function\",\n",
    "            additional_kwargs={\"name\": function_call[\"name\"]},\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fbc2cec5-6cc0-4814-92a1-ca0bd237528f",
   "metadata": {},
   "source": [
    "## Let's Try It Out!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "08928f6e-610c-420b-8a7b-7a7042bbd6c6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "agent = YourOpenAIAgent(tools=[multiply_tool, add_tool])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d5cefbad-32c4-4273-807a-cc179bae4473",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Hello! How can I assist you today?'"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "agent.chat(\"Hi\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b8f7650d-57b8-4ef4-b19d-651281ddb1be",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'The product of 2123 multiplied by 215123 is 456,706,129.'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "agent.chat(\"What is 2123 * 215123\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "707d30b8-6405-4187-a9ed-6146dcc42167",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Our (Slightly Better) `OpenAIAgent` Implementation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "798ca3fd-6711-4c0c-a853-d868dd14b484",
   "metadata": {},
   "source": [
    "We provide a (slightly better) `OpenAIAgent` implementation in LlamaIndex, which you can use directly as follows.\n",
    "\n",
    "In comparison to the simplified version above:\n",
    "* it implements the `BaseChatEngine` and `BaseQueryEngine` interfaces, so you can use it more seamlessly within the LlamaIndex framework\n",
    "* it supports multiple function calls per conversation turn\n",
    "* it supports streaming\n",
    "* it supports async endpoints\n",
    "* it supports callbacks and tracing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "38ab3938-1138-43ea-b085-f430b42f5377",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from llama_index.agent import OpenAIAgent\n",
    "from llama_index.llms import OpenAI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d852ece7-e5a1-4368-9d59-c7014e0b5b4d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "llm = OpenAI(model=\"gpt-3.5-turbo-0613\")\n",
    "agent = OpenAIAgent.from_tools([multiply_tool, add_tool], llm=llm, verbose=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "500cbee4",
   "metadata": {},
   "source": [
    "### Chat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "9fd1cad5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=== Calling Function ===\n",
      "Calling function: multiply with args: {\n",
      "  \"a\": 121,\n",
      "  \"b\": 3\n",
      "}\n",
      "Got output: 363\n",
      "========================\n",
      "=== Calling Function ===\n",
      "Calling function: add with args: {\n",
      "  \"a\": 363,\n",
      "  \"b\": 42\n",
      "}\n",
      "Got output: 405\n",
      "========================\n",
      "(121 * 3) + 42 = 405\n"
     ]
    }
   ],
   "source": [
    "response = agent.chat(\"What is (121 * 3) + 42?\")\n",
    "print(str(response))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb33983c",
   "metadata": {},
   "source": [
    "### Async Chat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "1d1fc974",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=== Calling Function ===\n",
      "Calling function: multiply with args: {\n",
      "  \"a\": 121,\n",
      "  \"b\": 3\n",
      "}\n",
      "Got output: 363\n",
      "========================\n",
      "121 * 3 = 363\n"
     ]
    }
   ],
   "source": [
    "response = await agent.achat(\"What is 121 * 3?\")\n",
    "print(str(response))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aae035cb",
   "metadata": {},
   "source": [
    "### Streaming Chat\n",
    "Here, every LLM response is returned as a generator. You can stream every incremental step, or only the last response."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "14217fb2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=== Calling Function ===\n",
      "Calling function: multiply with args: {\n",
      "  \"a\": 121,\n",
      "  \"b\": 2\n",
      "}\n",
      "Got output: 242\n",
      "========================\n",
      "121 * 2 = 242\n",
      "\n",
      "Once upon a time, in a small village, there was a group of mice who lived happily in a cozy little burrow. The leader of the group was a wise and courageous mouse named Milo. Milo was known for his intelligence and his ability to solve problems.\n",
      "\n",
      "One day, Milo gathered all the mice together and announced that they were facing a shortage of food. The mice were worried because winter was approaching, and they needed to find a way to gather enough food to survive the cold months ahead.\n",
      "\n",
      "Milo came up with a brilliant plan. He divided the group into teams of two, pairing each mouse with another. Each pair was assigned the task of collecting as much food as possible. The mice were excited and motivated to contribute to the group's survival.\n",
      "\n",
      "One pair of mice, named Bella and Max, set out on their mission. They scurried through fields and forests, searching for food. With their combined efforts, they were able to gather twice as much food as they had ever collected before. They found an abundance of nuts, seeds, and grains, enough to fill their tiny paws.\n",
      "\n",
      "As Bella and Max returned to the burrow, they were greeted with cheers and applause from the other mice. The food they had collected would ensure that everyone in the group would have enough to eat during the winter.\n",
      "\n",
      "The success of Bella and Max inspired the other mice to work together and multiply their efforts. They formed more pairs and ventured out into different directions, exploring new territories and discovering hidden food sources. With each pair's contribution, the group's food supply grew exponentially.\n",
      "\n",
      "Thanks to the multiplication of their efforts, the mice not only survived the winter but thrived. They had enough food to share with neighboring animals and even helped other creatures in need.\n",
      "\n",
      "The story of the group of mice taught everyone the power of collaboration and the importance of working together towards a common goal. It showed that by multiplying their efforts, they could overcome any challenge and achieve great things.\n",
      "\n",
      "And so, the mice lived happily ever after, always remembering the lesson they had learned about the power of multiplication and unity."
     ]
    }
   ],
   "source": [
    "agent_stream = agent.stream_chat(\n",
    "    \"What is 121 * 2? Once you have the answer, use that number to write a story about a group of mice.\"\n",
    ")\n",
    "for response in agent_stream:\n",
    "    response_gen = response.response_gen\n",
    "    # NOTE: here, we skip any intermediate steps and wait until the last response\n",
    "    # intermediate steps usually only contain function calls though\n",
    "    # for token in response_gen:\n",
    "    #     print(token, end=\"\")\n",
    "\n",
    "for token in response_gen:\n",
    "    print(token, end=\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3fac119f",
   "metadata": {},
   "source": [
    "### Async Streaming Chat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "33ea069f-819b-4ec1-a93c-fcbaacb362a1",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=== Calling Function ===\n",
      "Calling function: add with args: {\n",
      "  \"a\": 121,\n",
      "  \"b\": 8\n",
      "}\n",
      "Got output: 129\n",
      "========================\n",
      "121 + 8 = 129\n",
      "\n",
      "Once upon a time, in a lush green meadow, there was a group of mice who lived harmoniously in a cozy burrow. The leader of the group was a wise and compassionate mouse named Oliver. Oliver was known for his kindness and his ability to bring the mice together.\n",
      "\n",
      "One sunny morning, as the mice were going about their daily activities, they noticed a group of birds flying overhead. The birds were migrating to a warmer climate, and they dropped a small bag of seeds as they passed by. The mice were excited to find this unexpected gift and quickly gathered around to examine it.\n",
      "\n",
      "Oliver, being the leader, suggested that they share the seeds among themselves. He believed that by working together, they could make the most of this fortunate event. The mice agreed and decided to divide the seeds equally among all the members of the group.\n",
      "\n",
      "As they distributed the seeds, they realized that there were 129 seeds in total. Each mouse received an equal share of 129 divided by the number of mice in the group. They were delighted to have this additional food source, which would help sustain them during the upcoming winter.\n",
      "\n",
      "The mice decided to plant the seeds in a nearby field, hoping to grow a bountiful harvest. They worked tirelessly, digging small holes and carefully placing the seeds in the ground. They watered the seeds and nurtured them with love and care.\n",
      "\n",
      "Days turned into weeks, and the mice watched with anticipation as tiny sprouts emerged from the soil. The sprouts grew into beautiful plants, bearing fruits and vegetables that the mice had never seen before. The field became a vibrant garden, providing an abundance of food for the entire group.\n",
      "\n",
      "The mice celebrated their successful harvest with a grand feast. They enjoyed the fruits of their labor, grateful for the unity and cooperation that had led to their prosperity. Oliver, with tears of joy in his eyes, thanked each and every mouse for their contribution and reminded them of the power of togetherness.\n",
      "\n",
      "From that day forward, the mice continued to work together, sharing their resources and supporting one another. They thrived in their meadow, living in harmony and spreading their message of unity to other creatures they encountered.\n",
      "\n",
      "The story of the group of mice taught everyone the importance of coming together and sharing resources. It showed that by adding their efforts and working as a team, they could overcome challenges and achieve abundance. The mice became an inspiration to all, reminding everyone of the power of collaboration and the strength that lies in unity."
     ]
    }
   ],
   "source": [
    "chat_gen = agent.astream_chat(\n",
    "    \"What is 121 + 8? Once you have the answer, use that number to write a story about a group of mice.\"\n",
    ")\n",
    "\n",
    "async for response in chat_gen:\n",
    "    response_gen = response.response_gen\n",
    "    # NOTE: here, we skip any intermediate steps and wait until the last response\n",
    "    # intermediate steps usually only contain function calls though\n",
    "    # for token in response_gen:\n",
    "    #     print(token, end=\"\")\n",
    "\n",
    "for token in response_gen:\n",
    "    print(token, end=\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2fe399c5-6d07-4926-b701-b612efd56b30",
   "metadata": {},
   "source": [
    "### Agent with Personality"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b47c034-f948-4604-a8d8-828b617ea245",
   "metadata": {},
   "source": [
    "You can specify a system prompt to give the agent additional instructions or a personality."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "bef36d1e-c26e-4b07-b3d0-3b7f314a45f5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from llama_index.agent import OpenAIAgent\n",
    "from llama_index.llms import OpenAI\n",
    "from llama_index.prompts.system import SHAKESPEARE_WRITING_ASSISTANT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eba7fa46-1173-42f2-885c-0cc28df1cd2e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "llm = OpenAI(model=\"gpt-3.5-turbo-0613\")\n",
    "\n",
    "agent = OpenAIAgent.from_tools(\n",
    "    [multiply_tool, add_tool],\n",
    "    llm=llm,\n",
    "    verbose=True,\n",
    "    system_prompt=SHAKESPEARE_WRITING_ASSISTANT,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4841778-7008-4b61-afcc-995b6b64e91a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "response = agent.chat(\"Hi\")\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46a83768-1203-4485-a346-2fa78089afb1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "response = agent.chat(\"Tell me a story\")\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "096c322b-bf5d-4140-9dcf-341062a4151b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
